{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# GENRE for fairseq\n",
    "\n",
    "First make sure that you have [fairseq](https://github.com/pytorch/fairseq) installed.\n",
    "Since `fairseq` is going through breaking changes please install it from [this](https://github.com/nicola-decao/fairseq/tree/fixing_prefix_allowed_tokens_fn) fork using: \n",
    "```bash\n",
    "git clone --branch fixing_prefix_allowed_tokens_fn https://github.com/nicola-decao/fairseq\n",
    "cd fairseq\n",
    "pip install --editable ./\n",
    "```\n",
    "as described in the [fairseq repository](https://github.com/pytorch/fairseq#requirements-and-installation) since `pip install fairseq` has issues. \n",
    "\n",
    "# GENRE for transformers\n",
    "\n",
    "First make sure that you have [transformers](https://github.com/huggingface/transformers) >=4.2.0 installed. \n",
    "**NOTE: we used fairseq for all experiments in the paper. The huggingface/transformers models are obtained with a [conversion script](https://github.com/facebookresearch/GENRE/blob/main/scripts_genre/convert_bart_original_pytorch_checkpoint_to_pytorch.py).**\n",
    "\n",
    "<hr>\n",
    "\n",
    "# Datasets\n",
    "\n",
    "Use the links below to download datasets. As an alternative use [this](https://github.com/facebookresearch/GENRE/blob/main/scripts/download_all_datasets.sh) script to dowload all of them. These dataset (except BLINK data) are a pre-processed version of [Phong Le and Ivan Titov (2018)](https://arxiv.org/pdf/1804.10637.pdf) data availabe [here](https://github.com/lephong/mulrel-nel). BLINK data taken from [here](https://github.com/facebookresearch/KILT).\n",
    "\n",
    "## Entity Disambiguation (train / dev)\n",
    "- [BLINK train](http://dl.fbaipublicfiles.com/KILT/blink-train-kilt.jsonl) (9,000,000 lines, 11GiB)\n",
    "- [BLINK dev](http://dl.fbaipublicfiles.com/KILT/blink-dev-kilt.jsonl) (10,000 lines, 13MiB)\n",
    "- [AIDA-YAGO2 train](http://dl.fbaipublicfiles.com/GENRE/aida-train-kilt.jsonl) (18,448 lines, 56MiB)\n",
    "- [AIDA-YAGO2 dev](http://dl.fbaipublicfiles.com/GENRE/aida-dev-kilt.jsonl) (4,791 lines, 15MiB)\n",
    "\n",
    "## Entity Disambiguation (test)\n",
    "- [ACE2004](http://dl.fbaipublicfiles.com/GENRE/ace2004-test-kilt.jsonl) (257 lines, 850KiB)\n",
    "- [AQUAINT](http://dl.fbaipublicfiles.com/GENRE/aquaint-test-kilt.jsonl) (727 lines, 2.0MiB)\n",
    "- [AIDA-YAGO2](http://dl.fbaipublicfiles.com/GENRE/aida-test-kilt.jsonl) (4,485 lines, 14MiB)\n",
    "- [MSNBC](http://dl.fbaipublicfiles.com/GENRE/msnbc-test-kilt.jsonl) (656 lines, 1.9MiB)\n",
    "- [WNED-CWEB](http://dl.fbaipublicfiles.com/GENRE/clueweb-test-kilt.jsonl) (11,154 lines, 38MiB)\n",
    "- [WNED-WIKI](http://dl.fbaipublicfiles.com/GENRE/wiki-test-kilt.jsonl) (6,821 lines, 19MiB)\n",
    "\n",
    "## Entity Linking (train)\n",
    "- [WIKI-ABSTRACTS](http://dl.fbaipublicfiles.com/GENRE/train_data_e2eEL.tar.gz) (6,221,563 lines, 5.1GiB)\n",
    "\n",
    "## Document Retieval\n",
    "- KILT for the these datasets please follow the download instruction on the [KILT](https://github.com/facebookresearch/KILT) repository.\n",
    "\n",
    "## Pre-processing\n",
    "To pre-process a KILT formatted dataset into source and target files as expected from `fairseq` use \n",
    "```bash\n",
    "python scripts/convert_kilt_to_fairseq.py $INPUT_FILENAME $OUTPUT_FOLDER\n",
    "```\n",
    "Then, to tokenize and binarize them as expected from `fairseq` use \n",
    "```bash\n",
    "./preprocess_fairseq.sh $DATASET_PATH $MODEL_PATH\n",
    "```\n",
    "note that this requires to have `fairseq` source code downloaded in the same folder as the `genre` repository (see [here](https://github.com/facebookresearch/GENRE/blob/main/scripts/preprocess_fairseq.sh#L14)).\n",
    "\n",
    "## Trie from KILT Wikipedia titles\n",
    "\n",
    "We also release the BPE prefix tree (trie) from KILT Wikipedia titles ([kilt_titles_trie_dict.pkl](http://dl.fbaipublicfiles.com/GENRE/kilt_titles_trie_dict.pkl)) that is based on the 2019/08/01 Wikipedia dump, downloadable in its raw format [here](http://dl.fbaipublicfiles.com/BLINK/enwiki-pages-articles.xml.bz2).\n",
    "The trie contains ~5M titles and it is used to generate entites for all the KILT experiments.\n",
    "\n",
    "<hr>\n",
    "\n",
    "# Example: Entity Disambiguation\n",
    "Download one of the pre-trained models:\n",
    "\n",
    "| Training Dataset | [pytorch / fairseq](https://github.com/pytorch/fairseq)   | [huggingface / transformers](https://github.com/huggingface/transformers) |\n",
    "| -------- | -------- | ----------- |\n",
    "| BLINK | [fairseq_entity_disambiguation_blink](http://dl.fbaipublicfiles.com/GENRE/fairseq_entity_disambiguation_blink.tar.gz)|[hf_entity_disambiguation_blink](http://dl.fbaipublicfiles.com/GENRE/hf_entity_disambiguation_blink.tar.gz)|\n",
    "| BLINK + AidaYago2 | [fairseq_entity_disambiguation_aidayago](http://dl.fbaipublicfiles.com/GENRE/fairseq_entity_disambiguation_aidayago.tar.gz)|[hf_entity_disambiguation_aidayago](http://dl.fbaipublicfiles.com/GENRE/hf_entity_disambiguation_aidayago.tar.gz)|\n",
    "\n",
    "as well as the prefix tree from KILT Wikipedia titles ([kilt_titles_trie_dict.pkl](http://dl.fbaipublicfiles.com/GENRE/kilt_titles_trie_dict.pkl)).\n",
    "\n",
    "Then load the trie and define the function to apply the constraints with the entities trie"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# OPTIONAL:\n",
    "import sys\n",
    "sys.path.append(\"../\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "from genre.trie import Trie\n",
    "\n",
    "# load the prefix tree (trie)\n",
    "with open(\"../data/kilt_titles_trie_dict.pkl\", \"rb\") as f:\n",
    "    trie = Trie.load_from_dict(pickle.load(f))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then, load the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# for pytorch/fairseq\n",
    "from genre.fairseq_model import GENRE\n",
    "model = GENRE.from_pretrained(\"../models/fairseq_entity_disambiguation_aidayago\").eval()\n",
    "\n",
    "# for huggingface/transformers\n",
    "# from genre.hf_model import GENRE\n",
    "# model = GENRE.from_pretrained(\"../models/hf_entity_disambiguation_aidayago\").eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "and simply use `.sample` to make predictions constraining using `prefix_allowed_tokens_fn`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[{'text': 'Germany', 'score': tensor(-0.1856)},\n",
       "  {'text': 'Germans', 'score': tensor(-0.5461)},\n",
       "  {'text': 'German Empire', 'score': tensor(-2.1858)},\n",
       "  {'text': 'Nazi Germany', 'score': tensor(-2.4682)},\n",
       "  {'text': 'France', 'score': tensor(-4.2070)}]]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sentences = [\"Einstein was a [START_ENT] German [END_ENT] physicist.\"]\n",
    "\n",
    "model.sample(\n",
    "    sentences,\n",
    "    prefix_allowed_tokens_fn=lambda batch_id, sent: trie.get(sent.tolist()),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Example: Document Retieval\n",
    "Download one of the pre-trained models:\n",
    "\n",
    "| Training Dataset | [pytorch / fairseq](https://github.com/pytorch/fairseq)   | [huggingface / transformers](https://github.com/huggingface/transformers) |\n",
    "| -------- | -------- | ----------- |\n",
    "| [KILT](https://github.com/facebookresearch/KILT) | [fairseq_wikipage_retrieval](http://dl.fbaipublicfiles.com/GENRE/fairseq_wikipage_retrieval.tar.gz)|[hf_wikipage_retrieval](http://dl.fbaipublicfiles.com/GENRE/hf_wikipage_retrieval.tar.gz)|\n",
    "\n",
    "Then, load the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# for pytorch/fairseq\n",
    "from genre.fairseq_model import GENRE\n",
    "model = GENRE.from_pretrained(\"../models/fairseq_wikipage_retrieval\").eval()\n",
    "\n",
    "# for huggingface/transformers\n",
    "# from genre.hf_model import GENRE\n",
    "# model = GENRE.from_pretrained(\"../models/hf_wikipage_retrieval\").eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "and simply use `.sample` to make predictions constraining using `prefix_allowed_tokens_fn`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[{'text': 'Albert Einstein', 'score': tensor(-0.0708)},\n",
       "  {'text': 'Werner Bruschke', 'score': tensor(-1.5357)},\n",
       "  {'text': 'Werner von Habsburg', 'score': tensor(-1.8696)},\n",
       "  {'text': 'Werner von Moltke', 'score': tensor(-2.2318)},\n",
       "  {'text': 'Werner von Eichstedt', 'score': tensor(-3.0177)}]]"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sentences = [\"Einstein was a German physicist.\"]\n",
    "\n",
    "model.sample(\n",
    "    sentences,\n",
    "    prefix_allowed_tokens_fn=lambda batch_id, sent: trie.get(sent.tolist()),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Example: End-to-End Entity Linking\n",
    "\n",
    "Download one of the pre-trained models:\n",
    "\n",
    "| Training Dataset | [pytorch / fairseq](https://github.com/pytorch/fairseq)   | [huggingface / transformers](https://github.com/huggingface/transformers) |\n",
    "| -------- | -------- | ----------- |\n",
    "| WIKIPEDIA | [fairseq_e2e_entity_linking_wiki_abs](http://dl.fbaipublicfiles.com/GENRE/fairseq_e2e_entity_linking_wiki_abs.tar.gz)|[hf_e2e_entity_linking_wiki_abs](http://dl.fbaipublicfiles.com/GENRE/hf_e2e_entity_linking_wiki_abs.tar.gz)|\n",
    "| WIKIPEDIA + AidaYago2 | [fairseq_e2e_entity_linking_aidayago](http://dl.fbaipublicfiles.com/GENRE/fairseq_e2e_entity_linking_aidayago.tar.gz)|[hf_e2e_entity_linking_aidayago](http://dl.fbaipublicfiles.com/GENRE/hf_e2e_entity_linking_aidayago.tar.gz)|\n",
    "\n",
    "Then, load the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# for pytorch/fairseq\n",
    "from genre.fairseq_model import GENRE\n",
    "from genre.entity_linking import get_end_to_end_prefix_allowed_tokens_fn_fairseq as get_prefix_allowed_tokens_fn\n",
    "from genre.utils import get_entity_spans_fairseq as get_entity_spans\n",
    "model = GENRE.from_pretrained(\"../models/fairseq_e2e_entity_linking_aidayago\").eval()\n",
    "\n",
    "# for huggingface/transformers\n",
    "# from genre.hf_model import GENRE\n",
    "# from genre.entity_linking import get_end_to_end_prefix_allowed_tokens_fn_hf as get_prefix_allowed_tokens_fn\n",
    "# from genre.utils import get_entity_spans_hf as get_entity_spans\n",
    "# model = GENRE.from_pretrained(\"../models/hf_e2e_entity_linking_aidayago\").eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "and \n",
    "1. get the `prefix_allowed_tokens_fn` with the only constraints to annotate the original sentence (i.e., no other constrains on mention nor candidates)\n",
    "2. use `.sample` to make predictions constraining using `prefix_allowed_tokens_fn`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[{'text': 'In 1921, { Einstein } [ Albert Einstein ] received a { Nobel Prize } [ Nobel Prize in Physiology or Medicine ].',\n",
       "   'score': tensor(-0.9068)},\n",
       "  {'text': 'In 1921, { Einstein } [ Albert Einstein ] received a { Nobel Prize } [ Nobel Prize in Physiology or Medicine ] {. } [ Nobel Prize in Physiology or Medicine ]',\n",
       "   'score': tensor(-0.9301)},\n",
       "  {'text': 'In 1921, { Einstein } [ Albert Einstein ] received a { Nobel Prize } [ Nobel Prize in Physiology or Medicine ] {. } [ Albert Einstein ]',\n",
       "   'score': tensor(-0.9943)},\n",
       "  {'text': 'In 1921, { Einstein } [ Albert Einstein ] received a { Nobel Prize } [ Nobel Prize in Physiology or Physiology ].',\n",
       "   'score': tensor(-1.0778)},\n",
       "  {'text': 'In 1921, { Einstein } [ Albert Einstein ] received a { Nobel Prize } [ Nobel Prize in Physiology or Medicine ] {. } [ Ernest Einstein ]',\n",
       "   'score': tensor(-1.1164)}]]"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sentences = [\"In 1921, Einstein received a Nobel Prize.\"]\n",
    "\n",
    "prefix_allowed_tokens_fn = get_prefix_allowed_tokens_fn(model, sentences)\n",
    "\n",
    "model.sample(\n",
    "    sentences,\n",
    "    prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can constrain the mentions with a prefix tree (no constrains on candidates)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[{'text': 'In 1921, { Einstein } [ Albert Einstein ] received a Nobel Prize.',\n",
       "   'score': tensor(-1.4009)},\n",
       "  {'text': 'In 1921, { Einstein } [ Einstein (crater) ] received a Nobel Prize.',\n",
       "   'score': tensor(-1.6665)},\n",
       "  {'text': 'In 1921, { Einstein } [ Albert Albert Einstein ] received a Nobel Prize.',\n",
       "   'score': tensor(-1.7498)},\n",
       "  {'text': 'In 1921, { Einstein } [ Ernest Einstein ] received a Nobel Prize.',\n",
       "   'score': tensor(-1.8327)},\n",
       "  {'text': 'In 1921, { Einstein } [ Max Einstein ] received a Nobel Prize.',\n",
       "   'score': tensor(-1.8757)}]]"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "prefix_allowed_tokens_fn = get_prefix_allowed_tokens_fn(\n",
    "    model,\n",
    "    sentences,\n",
    "    mention_trie=Trie([\n",
    "        model.encode(e)[1:].tolist()\n",
    "        for e in [\" Einstein\"]\n",
    "    ])\n",
    ")\n",
    "\n",
    "model.sample(\n",
    "    sentences,\n",
    "    prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can constrain the candidates with a prefix tree (no constrains on mentions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[{'text': 'In 1921, { Einstein } [ Albert Einstein ] received a { Nobel Prize } [ Nobel Prize in Physics ].',\n",
       "   'score': tensor(-0.8925)},\n",
       "  {'text': 'In 1921, { Einstein } [ Albert Einstein ] received a { Nobel Prize. } [ Nobel Prize in Physics ]',\n",
       "   'score': tensor(-0.8990)},\n",
       "  {'text': 'In 1921, { Einstein } [ Albert Einstein ] received a { Nobel } [ Nobel Prize in Physics ] { Prize } [ Nobel Prize in Physics ].',\n",
       "   'score': tensor(-0.9330)},\n",
       "  {'text': 'In 1921, { Einstein } [ Albert Einstein ] received a { Nobel Prize } [ Nobel Prize in Physics ] {. } [ Nobel Prize in Physics ]',\n",
       "   'score': tensor(-0.9781)},\n",
       "  {'text': 'In 1921, { Einstein } [ Albert Einstein ] received a { Nobel Prize } [ Nobel Prize in Physics ] {. } [ Albert Einstein ]',\n",
       "   'score': tensor(-0.9854)}]]"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "prefix_allowed_tokens_fn = get_prefix_allowed_tokens_fn(\n",
    "    model,\n",
    "    sentences,\n",
    "    candidates_trie=Trie([\n",
    "        model.encode(\" }} [ {} ]\".format(e))[1:].tolist()\n",
    "        for e in [\"Albert Einstein\", \"Nobel Prize in Physics\", \"NIL\"]\n",
    "    ])\n",
    ")\n",
    "\n",
    "model.sample(\n",
    "    sentences,\n",
    "    prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can constrain the candidate sets given a mention (no constrains on mentions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[{'text': 'In 1921, { Einstein } [ Einstein ] received a { Nobel } [ Nobel Prize in Physics ] Prize.',\n",
       "   'score': tensor(-1.5417)},\n",
       "  {'text': 'In 1921, { Einstein } [ Einstein ] received a Nobel Prize.',\n",
       "   'score': tensor(-2.1319)},\n",
       "  {'text': 'In 1921, { Einstein } [ Einstein ] received a { Nobel } [ Nobel Prize in Physics ] { Prize } [ NIL ].',\n",
       "   'score': tensor(-2.2816)},\n",
       "  {'text': 'In 1921, { Einstein } [ Einstein ] received a { Nobel } [ Nobel Prize in Physics ] { Prize. } [ NIL ]',\n",
       "   'score': tensor(-2.3914)},\n",
       "  {'text': 'In 1921, { Einstein } [ Einstein ] received a { Nobel Prize. } [ NIL ]',\n",
       "   'score': tensor(-2.9078)}]]"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "prefix_allowed_tokens_fn = get_prefix_allowed_tokens_fn(\n",
    "    model,\n",
    "    sentences,\n",
    "    mention_to_candidates_dict={\n",
    "        \"Einstein\": [\"Einstein\"],\n",
    "        \"Nobel\": [\"Nobel Prize in Physics\"],\n",
    "    }\n",
    ")\n",
    "\n",
    "model.sample(\n",
    "    sentences,\n",
    "    prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A combiation of these constraints is also possible"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[{'text': 'In 1921, { Einstein } [ Albert Einstein ] received a { Nobel Prize } [ Nobel Prize in Physics ].',\n",
       "   'score': tensor(-0.8925)},\n",
       "  {'text': 'In 1921, { Einstein } [ Einstein (surname) ] received a { Nobel Prize } [ Nobel Prize in Physics ].',\n",
       "   'score': tensor(-1.3275)},\n",
       "  {'text': 'In 1921, { Einstein } [ Albert Einstein ] received a Nobel Prize.',\n",
       "   'score': tensor(-1.4009)},\n",
       "  {'text': 'In 1921, Einstein received a { Nobel Prize } [ Nobel Prize in Physics ].',\n",
       "   'score': tensor(-1.8266)},\n",
       "  {'text': 'In 1921, Einstein received a Nobel Prize.',\n",
       "   'score': tensor(-3.4495)}]]"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "prefix_allowed_tokens_fn = get_prefix_allowed_tokens_fn(\n",
    "    model,\n",
    "    sentences,\n",
    "    mention_trie=Trie([\n",
    "        model.encode(\" {}\".format(e))[1:].tolist()\n",
    "        for e in [\"Einstein\", \"Nobel Prize\"]\n",
    "    ]),\n",
    "    mention_to_candidates_dict={\n",
    "        \"Einstein\": [\"Albert Einstein\", \"Einstein (surname)\"],\n",
    "        \"Nobel Prize\": [\"Nobel Prize in Physics\", \"Nobel Prize in Medicine\"],\n",
    "    }\n",
    ")\n",
    "\n",
    "model.sample(\n",
    "    sentences,\n",
    "    prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, you can also use some functions from `genre.utils` that wraps pre- and post-processing of strings (e.g., normalization and outputs the character offsets and length of the mentions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[(9, 8, 'Albert_Einstein'), (29, 11, 'Nobel_Prize_in_Physics')]]"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "get_entity_spans(\n",
    "    model,\n",
    "    sentences,\n",
    "    mention_trie=Trie([\n",
    "        model.encode(\" {}\".format(e))[1:].tolist()\n",
    "        for e in [\"Einstein\", \"Nobel Prize\"]\n",
    "    ]),\n",
    "    mention_to_candidates_dict={\n",
    "        \"Einstein\": [\"Albert Einstein\", \"Einstein (surname)\"],\n",
    "        \"Nobel Prize\": [\"Nobel Prize in Physics\", \"Nobel Prize in Medicine\"],\n",
    "    }\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "and with the `entity_spans` generate Markdown with clickable links"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "In 1921, [Einstein](https://en.wikipedia.org/wiki/Albert_Einstein) received a [Nobel Prize](https://en.wikipedia.org/wiki/Nobel_Prize_in_Physics)."
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from genre.utils import get_markdown\n",
    "from IPython.display import Markdown\n",
    "\n",
    "entity_spans = get_entity_spans(\n",
    "    model,\n",
    "    sentences,\n",
    "    mention_trie=Trie([\n",
    "        model.encode(\" {}\".format(e))[1:].tolist()\n",
    "        for e in [\"Einstein\", \"Nobel Prize\"]\n",
    "    ]),\n",
    "    mention_to_candidates_dict={\n",
    "        \"Einstein\": [\"Albert Einstein\", \"Einstein (surname)\"],\n",
    "        \"Nobel Prize\": [\"Nobel Prize in Physics\", \"Nobel Prize in Medicine\"],\n",
    "    }\n",
    ")\n",
    "\n",
    "Markdown(get_markdown(sentences, entity_spans)[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Custom End-to-End Entity Linking evaluation\n",
    "\n",
    "We have some useful function to evaluate End-to-End Entity Linking predictions. Let's suppose we have a `Dict[str, str]` with document IDs and text as well as the gold entites spans as a `List[Tuple[str, int, int, str]]` containing documentID, start offset, length and entity title respectively."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "documents = {\n",
    "    \"id_0\": \"In 1921, Einstein received a Nobel Prize.\",\n",
    "    \"id_1\": \"Armstrong was the first man on the Moon.\",\n",
    "}\n",
    "\n",
    "gold_entities = [\n",
    "    (\"id_0\", 3, 4, \"1921\"),\n",
    "    (\"id_0\", 9, 8, 'Albert_Einstein'),\n",
    "    (\"id_0\", 29, 11, 'Nobel_Prize_in_Physics'),\n",
    "    (\"id_1\", 0, 9, 'Neil_Armstrong'),\n",
    "    (\"id_1\", 35, 4, 'Moon'),\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then we can get preditions and using `get_entity_spans_fairseq` to have spans. `guess_entities` is then a `List[List[Tuple[int, int, str]]]` containing for each document, a list of entity spans (without the document ID). We further need to add documentIDs to `guess_entities` and remove the nested list to be compatible with `gold_entities`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "guess_entities = get_entity_spans(\n",
    "    model,\n",
    "    list(documents.values()),\n",
    ")\n",
    "\n",
    "guess_entities = [\n",
    "    (k,) + x\n",
    "    for k, e in zip(documents.keys(), guess_entities)\n",
    "    for x in e\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, we can import all functions from `genre.utils` to compute scores."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "micro_p=0.2500 micro_r=0.4000, micro_f1=0.3077, macro_p=0.2500, macro_r=0.4167, macro_f1=0.3095\n"
     ]
    }
   ],
   "source": [
    "from genre.utils import (\n",
    "    get_micro_precision,\n",
    "    get_micro_recall,\n",
    "    get_micro_f1,\n",
    "    get_macro_precision,\n",
    "    get_macro_recall,\n",
    "    get_macro_f1,\n",
    ")\n",
    "\n",
    "micro_p = get_micro_precision(guess_entities, gold_entities)\n",
    "micro_r = get_micro_recall(guess_entities, gold_entities)\n",
    "micro_f1 = get_micro_f1(guess_entities, gold_entities)\n",
    "macro_p = get_macro_precision(guess_entities, gold_entities)\n",
    "macro_r = get_macro_recall(guess_entities, gold_entities)\n",
    "macro_f1 = get_macro_f1(guess_entities, gold_entities)\n",
    "\n",
    "print(\n",
    "   \"micro_p={:.4f} micro_r={:.4f}, micro_f1={:.4f}, macro_p={:.4f}, macro_r={:.4f}, macro_f1={:.4f}\".format(\n",
    "       micro_p, micro_r, micro_f1, macro_p, macro_r, macro_f1\n",
    "   )\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "assert 2 / 8 == micro_p\n",
    "assert 2 / 5 == micro_r\n",
    "assert 2 * (2 / 8 * 2 / 5) / (2 / 8 + 2 / 5) == micro_f1\n",
    "assert (1 / 4 + 1 / 4) / 2 == macro_p\n",
    "assert (1 / 3 + 1 / 2) / 2 == macro_r\n",
    "assert (2 * (1 / 4 * 1 / 3) / (1 / 4 + 1 / 3) + 2 * (1 / 4 * 1 / 2) / (1 / 4 + 1 / 2)) / 2 == macro_f1"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
